import os
import re
import logging
import pickle
import numpy as np
from typing import Callable, Dict, Iterable, List, Tuple, Optional

import torch.nn.functional as F

from transformers import (
    GPT2Tokenizer,
    GPT2TokenizerFast,
)
from rouge_score import rouge_scorer, scoring

from filelock import FileLock
try:
    import nltk

    NLTK_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    NLTK_AVAILABLE = False

# Import-time side effect: when nltk is importable, fetch its "punkt" data
# (nltk.sent_tokenize is used by add_newline_to_end_of_each_sentence below).
# The download runs while holding a file lock on ".lock" in the current
# working directory — presumably so concurrent processes importing this
# module do not download at the same time.
if NLTK_AVAILABLE:
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Flat list of the special token strings (not referenced by the functions visible in this module).
SPECIAL_TOKENS = ["<bos>", "<|endoftext|>", "<speaker1>", "<speaker2>", "<pad>"]
# Mapping handed to tokenizer.add_special_tokens() by add_special_tokens_().
ATTR_TO_SPECIAL_TOKEN = {'bos_token': '<bos>', 'pad_token': '<pad>',
                         'additional_special_tokens': ['<speaker1>', '<speaker2>']}
# Default value of the rouge_keys parameter of calculate_rouge().
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

def add_special_tokens_(model, tokenizer):
    """Add special tokens to the tokenizer and the model if they have not already been added.

    The tokens come from ``ATTR_TO_SPECIAL_TOKEN``; GPT-2 tokenizers (slow or
    fast) additionally get ``<|endoftext|>`` registered as ``eos_token``.

    Args:
        model: object exposing ``resize_token_embeddings(new_num_tokens=...)``.
        tokenizer: object supporting ``len()`` and ``add_special_tokens(dict)``,
            the latter returning the number of tokens actually added.

    Returns:
        None. ``tokenizer`` and ``model`` are modified in place.
    """
    orig_num_tokens = len(tokenizer)
    # Work on a copy: updating the module-level mapping in place would leak the
    # GPT-2 specific eos_token into every later call, whatever tokenizer it gets.
    special_tokens = dict(ATTR_TO_SPECIAL_TOKEN)
    if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)):
        special_tokens['eos_token'] = '<|endoftext|>'
    num_added_tokens = tokenizer.add_special_tokens(special_tokens)  # doesn't add if they are already there
    if num_added_tokens > 0:
        # Grow the embedding matrix only when the vocabulary actually grew.
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)

def add_special_tokens_for_qa(model, tokenizer):
    """Register the ``<speaker1>``/``<speaker2>`` markers and resize the model to match.

    Args:
        model: object exposing ``resize_token_embeddings(new_num_tokens=...)``.
        tokenizer: object supporting ``len()`` and ``add_special_tokens(dict)``,
            the latter returning how many tokens were newly added.

    Returns:
        None. The model is resized only when at least one token was added.
    """
    vocab_size_before = len(tokenizer)
    speaker_tokens = {'additional_special_tokens': ['<speaker1>', '<speaker2>']}
    added = tokenizer.add_special_tokens(speaker_tokens)
    if added <= 0:
        # Nothing new in the vocabulary, so the embeddings already fit.
        return
    model.resize_token_embeddings(new_num_tokens=vocab_size_before + added)


def remove_v_head(info):
    """Return, in order, the entries of ``info`` that do not start with ``"v_head"``."""
    kept = []
    for entry in info:
        if entry.startswith("v_head"):
            continue
        kept.append(entry)
    return kept


def save(toBeSaved, filename, mode='wb'):
    """Pickle ``toBeSaved`` into ``filename``, creating the parent directory if needed.

    Args:
        toBeSaved: any picklable object.
        filename: destination path; may be a bare filename in the current directory.
        mode: mode passed to ``open`` (binary write by default).

    Returns:
        None.
    """
    dirname = os.path.dirname(filename)
    # A bare filename has an empty dirname, and os.makedirs('') raises.
    # exist_ok also avoids failing if the directory appears between check and creation.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # The context manager closes the handle even if pickling fails midway.
    with open(filename, mode) as handle:
        # protocol 4 allows large size object, it's the default since python 3.8
        pickle.dump(toBeSaved, handle, protocol=4)


def add_newline_to_end_of_each_sentence(x: str) -> str:
    """Strip ``<n>`` markers from ``x`` and put each sentence on its own line.

    This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.

    Args:
        x: text to split into sentences.

    Returns:
        The sentences found by ``nltk.sent_tokenize``, joined with ``"\\n"``.

    Raises:
        AssertionError: if nltk could not be imported.
    """
    # Strings are immutable, so the substitution result has to be re-bound;
    # otherwise the pegasus newline char is never actually removed.
    x = re.sub("<n>", "", x)  # remove pegasus newline char
    assert NLTK_AVAILABLE, "nltk must be installed to separate newlines between sentences. (pip install nltk)"
    return "\n".join(nltk.sent_tokenize(x))

# Evaluation metric related
def lmap(f: Callable, x: Iterable) -> List:
    """Apply ``f`` to every item of ``x`` and return the results as a list."""
    return [f(item) for item in x]

def calculate_bleu(output_lns, refs_lns) -> dict:
    """Uses multi-bleu-detok.perl implementation.

    Despite the ``_lns`` suffix, both arguments are interpolated into a shell
    command as file paths. ``multi-bleu-detok.perl`` is referenced by relative
    path, so it is looked up in the current working directory.

    Args:
        output_lns: path of the hypothesis file, redirected to the script's stdin.
        refs_lns: path passed to the script as its reference argument.

    Returns:
        Dict with keys ``bleu``, ``bleu4``, ``bleu3``, ``bleu2`` and ``bleu1``,
        each rounded to two decimals.

    Raises:
        RuntimeError: if the script prints nothing or its last line is not a
            BLEU summary (e.g. perl, the script or an input file is missing).
    """
    # NOTE(review): the paths are substituted into the shell command unquoted;
    # only pass trusted paths free of whitespace and shell metacharacters.
    command = f'perl multi-bleu-detok.perl {refs_lns} < {output_lns}'
    # The context manager closes the pipe and reaps the child process.
    with os.popen(command) as pipe:
        lines = pipe.readlines()
    if not lines:
        raise RuntimeError(f"multi-bleu-detok.perl produced no output for command: {command}")
    summary = lines[-1].strip()
    m = re.match(r'BLEU = (?P<bleu>.*), (?P<bleu1>.*)/(?P<bleu2>.*)/(?P<bleu3>.*)/(?P<bleu4>.*) \(BP=(?P<BP>.*), ratio=(?P<ratio>.*), hyp_len=(?P<hyp_len>.*), ref_len=(?P<ref_len>.*)\)', summary)
    if m is None:
        raise RuntimeError(f"Could not parse BLEU summary line: {summary!r}")
    return {name: round(float(m.group(name)), 2) for name in ("bleu", "bleu4", "bleu3", "bleu2", "bleu1")}


def extract_rouge_mid_statistics(dct):
    """Reduce each aggregated score in ``dct`` to its ``mid`` statistics as percentages.

    Args:
        dct: mapping whose values expose a ``mid`` attribute carrying
            ``precision``, ``recall`` and ``fmeasure``.

    Returns:
        Dict with the same keys; each value is a dict of the three statistics
        multiplied by 100 and rounded to two decimals.
    """
    def as_percentages(mid):
        return {
            stat: round(getattr(mid, stat) * 100, 2)
            for stat in ["precision", "recall", "fmeasure"]
        }

    return {key: as_percentages(aggregate.mid) for key, aggregate in dct.items()}

def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Compute ROUGE between generated and reference summaries with the rouge_scorer package.

    Args:
        pred_lns: summaries generated by the model.
        tgt_lns: ground-truth summaries (e.g. contents of val.target).
        use_stemmer: whether the Porter stemmer should strip word suffixes to
            improve matching.
        rouge_keys: metrics to compute; defaults to rouge1, rouge2, rougeL, rougeLsum.
        return_precision_and_recall: (False) whether to return precision and
            recall alongside the f-measure.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of
            scores. When False, the per-observation scores collected by the
            aggregator are returned instead.
        newline_sep: (True) whether to put a newline between sentences. This is
            essential for rougeL on multi-sentence summaries (CNN/DM dataset).

    Returns:
        With bootstrap aggregation: ``{metric: f-measure in percent}``, or
        ``{metric: {precision, recall, fmeasure}}`` when
        ``return_precision_and_recall`` is set. Without it: the aggregator's
        mapping of metric to list of per-observation scores.
    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    # Pairs are formed positionally; zip stops at the shorter of the two lists.
    for reference, hypothesis in zip(tgt_lns, pred_lns):
        if newline_sep:
            # rougeLsum expects "\n" separated sentences within a summary
            reference = add_newline_to_end_of_each_sentence(reference)
            hypothesis = add_newline_to_end_of_each_sentence(hypothesis)
        # The reference goes first and the model output second.
        aggregator.add_scores(scorer.score(reference, hypothesis))

    if not bootstrap_aggregation:
        return aggregator._scores  # raw per-observation scores, keyed by metric

    result = aggregator.aggregate()
    if return_precision_and_recall:
        return extract_rouge_mid_statistics(result)

    fmeasures = {}
    for key, aggregate in result.items():
        fmeasures[key] = round(aggregate.mid.fmeasure * 100, 2)
    return fmeasures



# def build_compute_metrics_fn(tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
#     def non_pad_len(tokens: np.ndarray) -> int:
#         return np.count_nonzero(tokens != tokenizer.pad_token_id)

#     def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
#         pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
#         label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
#         pred_str = lmap(str.strip, pred_str)
#         label_str = lmap(str.strip, label_str)
#         return pred_str, label_str
    
#     def summarization_metrics(pred: EvalPrediction) -> Dict:
#         pred_str, label_str = decode_pred(pred)
#         rouge: Dict = calculate_rouge(pred_str, label_str)
#         summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
#         rouge.update({"gen_len": summ_len})
#         return rouge

#     def translation_metrics(pred: EvalPrediction) -> Dict:
#         pred_str, label_str = decode_pred(pred)
#         bleu: Dict = calculate_bleu(pred_str, label_str)
#         gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
#         bleu.update({"gen_len": gen_len})
#         return bleu

#     compute_metrics_fn = summarization_metrics
#     return compute_metrics_fn

def build_compute_metrics_fn_gpt2(task, nlgeval=None):
    """Return the metric function registered for ``task``.

    Every returned function takes ``(pred, gold)`` and returns a dict.

    Args:
        task: one of ``"nlgeval_bleu"``, ``"bleu"``, ``"rouge"`` or
            ``"rouge1_recall"``; any other value raises ``KeyError``.
        nlgeval: object exposing ``compute_metrics(gold, pred)``; required
            (asserted at call time) only by the ``"nlgeval_bleu"`` function.

    Returns:
        The selected callable.
    """
    def rouge_metrics(pred, gold) -> Dict:
        return calculate_rouge(pred, gold)

    def rouge1_recall_metrics(pred, gold) -> Dict:
        # First pass: rouge1 only, with precision/recall, to pull out the recall.
        rouge1_only = calculate_rouge(pred, gold, rouge_keys=["rouge1"], return_precision_and_recall=True)
        combined = {"rouge1_recall": rouge1_only["rouge1"]["recall"]}
        # Second pass: the default f-measure report, merged after the recall entry.
        combined.update(calculate_rouge(pred, gold))
        return combined

    def nlgeval_bleu_metrics(pred, gold) -> Dict:
        assert nlgeval is not None
        # Note the argument order: gold comes first here.
        return nlgeval.compute_metrics(gold, pred)

    def bleu_metrics(pred, gold) -> Dict:
        return calculate_bleu(pred, gold)

    dispatch = dict(
        nlgeval_bleu=nlgeval_bleu_metrics,
        bleu=bleu_metrics,
        rouge=rouge_metrics,
        rouge1_recall=rouge1_recall_metrics,
    )
    return dispatch[task]


def postprocess_gpt2_predictions(
    predictions: Tuple[np.ndarray, np.ndarray],
    labels,
    tokenizer,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
    is_world_process_zero: bool = True,
):
    """Decode predicted and reference token ids to text and optionally write them to disk.

    Args:
        predictions: sequence of token-id sequences, one per example.
        labels: sequence of reference token-id sequences, same length as ``predictions``.
        tokenizer: object exposing ``decode(ids, skip_special_tokens=True)``.
        output_dir: when given, ``predictions.txt`` and ``references.txt`` are
            written there, one example per line; the directory is created if
            missing. When ``None`` nothing is written.
        prefix: currently unused.
        is_world_process_zero: currently unused.

    Returns:
        Tuple ``(preds, golds)`` of decoded prediction and reference strings.

    Raises:
        AssertionError: if ``predictions`` and ``labels`` differ in length.
    """
    # Logging. getLogger(__name__) returns the same object as the module-level logger.
    logging.getLogger(__name__).setLevel(logging.INFO)
    assert len(predictions) == len(labels), f"Got {len(predictions)} predictions and {len(labels)} examples."

    preds = []
    golds = []
    for prediction, label in zip(predictions, labels):
        preds.append(tokenizer.decode(prediction, skip_special_tokens=True))
        golds.append(tokenizer.decode(label, skip_special_tokens=True))

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        for filename, lines in (("predictions.txt", preds), ("references.txt", golds)):
            # Must be opened for writing; utf-8 keeps decoded text portable across platforms.
            with open(os.path.join(output_dir, filename), "w", encoding="utf-8") as f:
                f.writelines(line + "\n" for line in lines)
    return preds, golds